from typing import Union

import numpy as np

from autograd.builtins import SequenceBox
from autograd.extend import Box, primitive

from . import numpy_wrapper as anp

Box.__array_priority__ = 90.0


class ArrayBox(Box):
    __slots__ = []
    __array_priority__ = 100.0

    @primitive
    def __getitem__(A, idx):
        return A[idx]

    # Constants w.r.t float data just pass though
    shape = property(lambda self: self._value.shape)
    ndim = property(lambda self: self._value.ndim)
    size = property(lambda self: self._value.size)
    dtype = property(lambda self: self._value.dtype)
    T = property(lambda self: anp.transpose(self))

    def __array_namespace__(self, *, api_version: Union[str, None] = None):
        return anp

    def __len__(self):
        return len(self._value)

    def astype(self, *args, **kwargs):
        return anp._astype(self, *args, **kwargs)

    def __neg__(self):
        return anp.negative(self)

    def __add__(self, other):
        return anp.add(self, other)

    def __sub__(self, other):
        return anp.subtract(self, other)

    def __mul__(self, other):
        return anp.multiply(self, other)

    def __pow__(self, other):
        return anp.power(self, other)

    def __div__(self, other):
        return anp.divide(self, other)

    def __mod__(self, other):
        return anp.mod(self, other)

    def __truediv__(self, other):
        return anp.true_divide(self, other)

    def __matmul__(self, other):
        return anp.matmul(self, other)

    def __radd__(self, other):
        return anp.add(other, self)

    def __rsub__(self, other):
        return anp.subtract(other, self)

    def __rmul__(self, other):
        return anp.multiply(other, self)

    def __rpow__(self, other):
        return anp.power(other, self)

    def __rdiv__(self, other):
        return anp.divide(other, self)

    def __rmod__(self, other):
        return anp.mod(other, self)

    def __rtruediv__(self, other):
        return anp.true_divide(other, self)

    def __rmatmul__(self, other):
        return anp.matmul(other, self)

    def __eq__(self, other):
        return anp.equal(self, other)

    def __ne__(self, other):
        return anp.not_equal(self, other)

    def __gt__(self, other):
        return anp.greater(self, other)

    def __ge__(self, other):
        return anp.greater_equal(self, other)

    def __lt__(self, other):
        return anp.less(self, other)

    def __le__(self, other):
        return anp.less_equal(self, other)

    def __abs__(self):
        return anp.abs(self)

    def __hash__(self):
        return id(self)


ArrayBox.register(np.ndarray)
for type_ in [
    float,
    np.longdouble,
    np.float64,
    np.float32,
    np.float16,
    complex,
    np.clongdouble,
    np.complex64,
    np.complex128,
]:
    ArrayBox.register(type_)

# These numpy.ndarray methods are just refs to an equivalent numpy function
nondiff_methods = [
    "all",
    "any",
    "argmax",
    "argmin",
    "argpartition",
    "argsort",
    "nonzero",
    "searchsorted",
    "round",
]
diff_methods = [
    "clip",
    "compress",
    "cumprod",
    "cumsum",
    "diagonal",
    "max",
    "mean",
    "min",
    "prod",
    "ptp",
    "ravel",
    "repeat",
    "reshape",
    "squeeze",
    "std",
    "sum",
    "swapaxes",
    "take",
    "trace",
    "transpose",
    "var",
]
for method_name in nondiff_methods + diff_methods:
    setattr(ArrayBox, method_name, anp.__dict__[method_name])

# Flatten has no function, only a method.
setattr(ArrayBox, "flatten", anp.__dict__["ravel"])

if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
    SequenceBox.register(np.linalg._linalg.EigResult)
    SequenceBox.register(np.linalg._linalg.EighResult)
    SequenceBox.register(np.linalg._linalg.QRResult)
    SequenceBox.register(np.linalg._linalg.SlogdetResult)
    SequenceBox.register(np.linalg._linalg.SVDResult)
elif np.__version__ >= "1.25":
    SequenceBox.register(np.linalg.linalg.EigResult)
    SequenceBox.register(np.linalg.linalg.EighResult)
    SequenceBox.register(np.linalg.linalg.QRResult)
    SequenceBox.register(np.linalg.linalg.SlogdetResult)
    SequenceBox.register(np.linalg.linalg.SVDResult)
